{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandas import DataFrame\n",
    "from py2neo import Graph,Node,Relationship,NodeMatcher\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "# Connect to the Neo4j database.\n",
    "# The URI and credentials can be overridden through the NEO4J_URI / NEO4J_USER /\n",
    "# NEO4J_PASSWORD environment variables, so real secrets never have to be saved in\n",
    "# the notebook. The fallbacks are the values this notebook originally hard-coded.\n",
    "graph = Graph(\n",
    "    os.environ.get('NEO4J_URI', 'http://localhost:7474/db/data/'),\n",
    "    username=os.environ.get('NEO4J_USER', 'neo4j'),\n",
    "    password=os.environ.get('NEO4J_PASSWORD', '123456'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<py2neo.database.Cursor at 0x29e10ea2588>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: build two Person nodes and join them with a KNOWS relationship.\n",
    "a = Node('Person', name='Tom')\n",
    "graph.create(a)\n",
    "b = Node('Person', name='Bob')\n",
    "graph.create(b)\n",
    "r = Relationship(a, 'KNOWS', b)\n",
    "graph.create(r)\n",
    "\n",
    "# Read the Person nodes back into a DataFrame.\n",
    "node = DataFrame(\n",
    "    graph.run('MATCH (n:`Person`) RETURN n LIMIT 25')\n",
    ")\n",
    "\n",
    "# Read the Person-to-Person relationships back as well.\n",
    "relation = DataFrame(\n",
    "    graph.run('MATCH (n:`Person`)-[r]->(m:`Person`) return n,m,type(r)')\n",
    ")\n",
    "\n",
    "# Clean up. NOTE: the query matches every node, so this wipes the whole database,\n",
    "# not only the two demo nodes. The cursor returned by run() is the cell's output.\n",
    "graph.run('MATCH (n) OPTIONAL MATCH (n)-[r]-() DELETE n,r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the input tables from CSV files in the working directory.\n",
    "# Only stock_basic.csv is decoded as GBK; the rest use pandas' default encoding.\n",
    "stock = pd.read_csv('stock_basic.csv', encoding='gbk')\n",
    "holder = pd.read_csv('holders.csv')\n",
    "\n",
    "# Concept list and the stock-to-concept mapping.\n",
    "concept_num = pd.read_csv('concept.csv')\n",
    "concept = pd.read_csv('stock_concept.csv')\n",
    "\n",
    "# Tables backing the sh / sz membership links and the stock-to-stock links.\n",
    "sh = pd.read_csv('sh.csv')\n",
    "sz = pd.read_csv('sz.csv')\n",
    "corr = pd.read_csv('corr.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing.\n",
    "# Stocks without an industry get the placeholder label instead of NaN.\n",
    "stock['行业'] = stock['行业'].fillna(value='未知')\n",
    "# Drop rows that are duplicated across all columns, keeping the first occurrence\n",
    "# (these are the pandas defaults, so no arguments are needed).\n",
    "holder = holder.drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create entities (concepts, stocks, shareholders, Stock Connect programmes).\n",
    "\n",
    "# The two programme nodes get their own variable names: binding them to `sz` / `sh`\n",
    "# would overwrite the DataFrames of the same name loaded from sz.csv / sh.csv.\n",
    "sz_node = Node('深股通',名字='深股通')\n",
    "graph.create(sz_node)\n",
    "\n",
    "sh_node = Node('沪股通',名字='沪股通')\n",
    "graph.create(sh_node)\n",
    "\n",
    "# Rows are read by column position. One node is created per row of each table.\n",
    "for row in concept_num.values:\n",
    "    concept_node = Node('概念',概念代码=row[1],概念名称=row[2])\n",
    "    print('概念代码:'+str(row[1]),'概念名称:'+str(row[2]))\n",
    "    graph.create(concept_node)\n",
    "\n",
    "for row in stock.values:\n",
    "    stock_node = Node('股票',TS代码=row[1],股票名称=row[3],行业=row[4])\n",
    "    print('TS代码:'+str(row[1]),'股票名称:'+str(row[3]),'行业:'+str(row[4]))\n",
    "    graph.create(stock_node)\n",
    "\n",
    "# Each shareholder node also stores the TS code of the stock it holds (column 0),\n",
    "# which is the key later used to link it to that stock.\n",
    "for row in holder.values:\n",
    "    holder_node = Node('股东',TS代码=row[0],股东名称=row[1],持股数量=row[2],持股比例=row[3])\n",
    "    print('TS代码:'+str(row[0]),'股东名称:'+str(row[1]),'持股数量:'+str(row[2]))\n",
    "    graph.create(holder_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create relationships (stock-shareholder, stock-concept, stock-announcement,\n",
    "# stock-connect programme).\n",
    "\n",
    "matcher = NodeMatcher(graph)\n",
    "\n",
    "# Shareholder -> stock links.\n",
    "# holder has one row per shareholder, and the inner match already returns every\n",
    "# shareholder node carrying a given TS code. Each code is therefore handled once;\n",
    "# otherwise the same links would be created again for every holder row of a stock.\n",
    "linked_codes = set()\n",
    "for i in holder.values:\n",
    "    if i[0] in linked_codes:\n",
    "        continue\n",
    "    linked_codes.add(i[0])\n",
    "    a = matcher.match('股票',TS代码=i[0]).first()\n",
    "    if a is None:\n",
    "        # No stock node with this code, so there is nothing to attach shareholders to.\n",
    "        continue\n",
    "    for j in matcher.match('股东',TS代码=i[0]):\n",
    "        graph.create(Relationship(j,'参股',a))\n",
    "        print('TS',str(i[0]))\n",
    "            \n",
    "# Stock -> concept links. Rows whose stock or concept is not in the graph are skipped.\n",
    "for i in concept.values:\n",
    "    a = matcher.match('股票',TS代码=i[3]).first()\n",
    "    # first() yields None when nothing matches. Test by identity instead of routing\n",
    "    # the comparison through the node's own equality operator.\n",
    "    if a is None:\n",
    "        continue  # no need to query the concept when the stock is already missing\n",
    "    b = matcher.match('概念',概念代码=i[1]).first()\n",
    "    if b is None:\n",
    "        continue\n",
    "    graph.create(Relationship(a,'概念属于',b))\n",
    "\n",
    "# Stock -> announcement links: every file under notices/ is read as a CSV of\n",
    "# announcement rows (stock code, date, title, content by column position).\n",
    "noticesdir = os.listdir('notices')\n",
    "for n in noticesdir:\n",
    "    # os.path.join keeps the path valid on every OS; a hard-coded backslash\n",
    "    # separator only resolves on Windows.\n",
    "    notice = pd.read_csv(os.path.join('notices', n),encoding='utf_8_sig')\n",
    "    notice['content'] = notice['content'].fillna('空白')\n",
    "    for i in notice.values:\n",
    "        a = matcher.match('股票',TS代码=i[0]).first()\n",
    "        if a is None:\n",
    "            # Unknown stock code: skip the row rather than create an announcement\n",
    "            # node that cannot be linked to anything.\n",
    "            continue\n",
    "        b = Node('公告',日期=i[1],标题=i[2],内容=i[3])\n",
    "        graph.create(b)\n",
    "        graph.create(Relationship(a,'发布公告',b))\n",
    "        print(str(i[0]))\n",
    "        \n",
    "# Constituent stock -> Stock Connect programme links (Shenzhen first, then Shanghai).\n",
    "# The member lists are read from disk under their own names rather than through the\n",
    "# `sz` / `sh` globals, so this loop does not depend on those names still referring\n",
    "# to DataFrames when the cell runs.\n",
    "for members_csv, connect_label in (('sz.csv', '深股通'), ('sh.csv', '沪股通')):\n",
    "    # The programme node is the same for every row, so look it up once per table.\n",
    "    b = matcher.match(connect_label).first()\n",
    "    if b is None:\n",
    "        raise RuntimeError('missing ' + connect_label + ' node; create the entity nodes first')\n",
    "    for i in pd.read_csv(members_csv).values:\n",
    "        a = matcher.match('股票',TS代码=i[0]).first()\n",
    "        if a is None:\n",
    "            # Constituent is not among the loaded stocks; nothing to link.\n",
    "            continue\n",
    "        graph.create(Relationship(a,'成分股属于',b))\n",
    "        # NOTE(review): the lookup keys on column 0 but the log prints column 1 -\n",
    "        # confirm which column actually holds the TS code.\n",
    "        print('TS代码:'+str(i[1]),'--'+connect_label)\n",
    "\n",
    "# Build stock-to-stock links from corr.csv. The relationship type is the text of\n",
    "# column 3; columns 1 and 2 hold the two stock codes.\n",
    "corr = pd.read_csv('corr.csv')\n",
    "for i in corr.values:\n",
    "    # The last character of each code is dropped before the lookup.\n",
    "    # NOTE(review): presumably the file's codes carry one extra trailing character\n",
    "    # compared with the TS codes stored on the stock nodes - confirm against corr.csv.\n",
    "    a = matcher.match('股票',TS代码=i[1][:-1]).first()\n",
    "    b = matcher.match('股票',TS代码=i[2][:-1]).first()\n",
    "    if a is None or b is None:\n",
    "        # One endpoint is not in the graph, so there is nothing to connect.\n",
    "        continue\n",
    "    graph.create(Relationship(a,str(i[3]),b))\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
